#For analysis of SERT assay cell-cell coupling strength.
#inputs to code are max-projected 3 channel images from confocal, 25X 1024x1024 cropped to 400x400 centered on HA positive cell.
#channel 1 is isolectin / channel 2 is HA / channel 3 is serotonin 


#%% Imports
import numpy as np
import matplotlib.pyplot as plt
import cv2 #successfully installed on my device
from skimage import io
from PIL import Image as img
import math
from scipy.spatial import distance

#%% Identify Boundaries of the SERT-HA-Positive Probe Cell

# Create the figure up front with its axes hidden; the plt.imshow QC call
# later in the script draws into the current figure.
fig,ax = plt.subplots(layout="constrained")
plt.axis('off')

#take image file from leica confocal and split into grayscale within cv2 environment
img_stack = io.imread('msB_cell19_input.tif') # scikit-image opens multiple slices from a tiff stack by default
# Channels 0 (isolectin) and 1 (HA) are written to disk and read back through
# OpenCV as single-channel grayscale images.
# NOTE(review): plt.imsave presumably applies matplotlib's default colormap and
# min/max scaling before writing, so the grayscale read-back is not the raw
# intensity - confirm this round trip is intended.
plt.imsave('intermediateIB4.tif', img_stack[:,:,0], format='tiff')
plt.imsave('intermediateHA.tif', img_stack[:,:,1], format='tiff')
ib4_img = cv2.imread('intermediateIB4.tif', cv2.IMREAD_GRAYSCALE)
ha_img = cv2.imread('intermediateHA.tif', cv2.IMREAD_GRAYSCALE)

#blur and threshold ha (sert+) channel
blur_ha = cv2.GaussianBlur(ha_img,(5,5),0)
ret_ha, thresh_ha = cv2.threshold(blur_ha,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)

#define contours of ha channel & identify the true ha+ cell (versus noise)
contours_ha, hierarchy_ha = cv2.findContours(thresh_ha, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

# The HA+ cell is taken to be the contour with the largest area. Only a
# strictly positive area qualifies, and the first contour wins on ties.
# trueHA_contour stays a one-element list because later cells index it with [0]
# and pass it to cv2.drawContours.
areas_ha = [cv2.contourArea(c) for c in contours_ha]
maxArea = max(areas_ha, default=0)
if maxArea > 0:
    trueHA_contour = [contours_ha[areas_ha.index(maxArea)]]
else:
    trueHA_contour = []

ha_img_clr = cv2.cvtColor(ha_img, cv2.COLOR_GRAY2RGB)

#test whether the code is working up to this point...
cv2.drawContours(ha_img_clr, trueHA_contour, -1, (0,255,0), 1)
plt.imsave('cellTest_HACellID.png', ha_img_clr) #save this data quality-control checkpoint - identified bounds of HA positive cell

# Stop here, after the QC image has been written, rather than letting a later
# cell fail with an IndexError on trueHA_contour[0].
if not trueHA_contour:
    raise ValueError("no HA+ contour with non-zero area was found in the thresholded HA channel")

#%% Identify Boundaries of the Vasculature in Image Contiguous With HA-Positive Cell

def _first_large_child(contours, hierarchy, parent_index, min_area=500):
    """Return the first direct child contour of `parent_index` whose area exceeds `min_area`.

    `hierarchy` is the array returned by cv2.findContours; the code indexes it
    as hierarchy[0][i], using entry [2] as the first child of contour i and
    entry [0] as the next sibling, with -1 meaning "none".
    Returns None when no child is large enough.
    """
    child_index = hierarchy[0][parent_index][2]
    while child_index != -1:
        # larger than min_area: likely a loop; smaller may just be a nuclear artifact
        if cv2.contourArea(contours[child_index]) > min_area:
            return contours[child_index]
        child_index = hierarchy[0][child_index][0]
    return None


#blur and threshold isolectin (vasc.) channel
blur_ib4 = cv2.GaussianBlur(ib4_img,(21,21),0)
ret_ib4, thresh_ib4 = cv2.threshold(blur_ib4,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)

#define contours of the vasculature
contours_ib4, hierarchy_ib4 = cv2.findContours(thresh_ib4, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
if len(contours_ib4) == 0:
    raise ValueError("no vasculature contours were found in the thresholded isolectin channel")

#identify the center point of the ha+ contour
#from https://www.geeksforgeeks.org/python-opencv-find-center-of-contour/
M = cv2.moments(trueHA_contour[0])
if M['m00'] == 0:
    # Without a centroid nothing below can run, so fail here instead of
    # printing and then hitting a NameError on cx / cy.
    raise ValueError("error - could not detect a center point for ha+ cell contour")
cx = int(M['m10']/M['m00'])
cy = int(M['m01']/M['m00'])

#identify the vasculature feature contiguous with the ha+ cell
#strategy - identify the first contour that strictly contains the previously calculated center-of-mass of the ha object
allDist = np.zeros(len(contours_ib4)) # signed distance from (cx,cy) to each contour examined
vasc_index = -1
for k, c in enumerate(contours_ib4):
    allDist[k] = cv2.pointPolygonTest(c, (cx,cy), True) #calculate distance from contour to point
    if cv2.pointPolygonTest(c, (cx,cy), False) == 1:
        vasc_index = k
        break

#catch the edge case where no contour contains the center and select the closest by distance
#i.e. center of mass for ha falls outside of vasculature; cell is curved
centerOutside = vasc_index == -1
if centerOutside:
    # The loop above ran to completion, so allDist holds every contour's
    # distance; the largest value is used as "closest". argmax takes the first
    # on ties.
    closest_index = int(np.argmax(allDist))
    vasc_index = closest_index

vasc_contour = contours_ib4[vasc_index]

#if the chosen contour is nested, look to see if this constitutes a loop (i.e. capillary structure)
curr_daughter_c = _first_large_child(contours_ib4, hierarchy_ib4, vasc_index)
loopFound = curr_daughter_c is not None
# daughter_contour is what the QC cell draws in blue; it is the loop contour
# when one was found, otherwise an empty list.
daughter_contour = curr_daughter_c if loopFound else []

#%% Generate an Image of Calculated Boundaries for QC

# Overlay the detected outlines on an RGB copy of the isolectin image and save it.
ib4_img_clr = cv2.cvtColor(ib4_img, cv2.COLOR_GRAY2RGB)

qc_overlays = [
    (trueHA_contour, (0,255,0)), # identified boundaries of HA+ cell in GREEN
    (vasc_contour, (255,0,0)),   # identified boundaries of relevant vasculature in RED
]
if loopFound:
    # if a looping feature is found, draw inner loop in BLUE
    # NOTE(review): confirm daughter_contour is populated upstream; if it is
    # still an empty list nothing is drawn here.
    qc_overlays.append((daughter_contour, (0,0,255)))

for outline, rgb in qc_overlays:
    cv2.drawContours(ib4_img_clr, outline, -1, rgb, 1)

plt.imsave('cellTest_markUp_QC.png', ib4_img_clr)

#%% Create Mask of HA-Positive Cell From Contours And Calculate Average Serotonin Intensity Within Cell
ser_vals = np.array(img_stack[:,:,2]) # third channel of the input stack (serotonin)
mask_HA = np.zeros(np.shape(ser_vals))

# Mark every pixel strictly inside the HA+ contour with 255. With measureDist
# set to False, pointPolygonTest returns +1 inside, 0 on the edge and -1
# outside, so edge pixels are excluded.
ha_outline = trueHA_contour[0]
for x in range(np.shape(mask_HA)[1]):
    for y in range(np.shape(mask_HA)[0]):
        if cv2.pointPolygonTest(ha_outline, (x,y), False) >= 1:
            mask_HA[y,x] = 255

# Sum the serotonin channel over the mask with a 64-bit accumulator.
# A running `total += ser_vals[y,x]` can stay in the image's own dtype under
# NumPy >= 2 promotion rules and wrap around if that dtype is a narrow integer.
total_5HT_inner = float(ser_vals[mask_HA == 255].sum(dtype=np.float64))

#%% Create Mask of Vasculature Contiguous With HA-Positive Cell From Contours
mask_Vasc = np.zeros(np.shape(ser_vals))
n_rows_vasc = np.shape(mask_Vasc)[0]
n_cols_vasc = np.shape(mask_Vasc)[1]

# 255 for pixels strictly inside the vasculature contour, 0 for everything
# else (edge pixels included).
for col in range(n_cols_vasc):
    for row in range(n_rows_vasc):
        strictly_inside = cv2.pointPolygonTest(vasc_contour, (col,row), False) >= 1
        mask_Vasc[row,col] = 255 if strictly_inside else 0

if(loopFound):
    # Clear any pixel lying strictly within the daughter contour describing the loop.
    # The candidate pixels are listed once, before the mask is modified.
    for row, col in zip(*np.nonzero(mask_Vasc == 255)):
        # plain ints are passed so the point has the same type as in a range() loop
        if cv2.pointPolygonTest(curr_daughter_c, (int(col),int(row)), False) == 1:
            mask_Vasc[row,col] = 0

#%% Refine Mask of Vasculature to Capture Pixels Within Arbitrary Range of the HA-Positive Cell

# Store all (x,y) coordinates included in the HA mask for referencing below.
# Transposing first makes np.nonzero walk x in the outer position and y in the
# inner one.
ha_x, ha_y = np.nonzero(mask_HA.T == 255)
HA_coords = list(zip(ha_x.tolist(), ha_y.tolist()))
A = len(HA_coords) #size of the HA positive ROI
if A == 0:
    raise ValueError("HA mask is empty - cannot measure distances to the HA+ cell")

#First, remove the HA ROI from the vasculature mask
mask_Vasc_2 = mask_Vasc - mask_HA
mask_Vasc_2_img = img.fromarray(mask_Vasc_2).convert('LA')

#Next, identify a nearest subset of points
# For every vasculature pixel left after the HA subtraction, record its
# distance to the nearest HA pixel; the closest `window` pixels become the ROI.
vasc_x, vasc_y = np.nonzero(mask_Vasc_2.T == 255)
V = vasc_x.size
if V == 0:
    raise ValueError("no vasculature pixels remain after removing the HA ROI")

#create an array to store relevant data
dist2pt = np.zeros((V,4))
# only contains points demarcated as positive in the 'vasc2' mask - after HA subtract
# [i,0] = x coordinate
# [i,1] = y coordinate
# [i,2] = shortest euclidean distance to any pixel of the HA mask
# [i,3] = value in serotonin channel
dist2pt[:,0] = vasc_x
dist2pt[:,1] = vasc_y
dist2pt[:,3] = ser_vals[vasc_y, vasc_x]

# Compare points to every point in the HA mask a chunk at a time, so the
# intermediate distance matrix stays small (chunk_size x A values).
# Only the minimum is kept; the matching HA pixel is not needed.
ha_pts = np.asarray(HA_coords, dtype=float)
chunk_size = 256
for start in range(0, V, chunk_size):
    stop = start + chunk_size
    chunk_dist = distance.cdist(dist2pt[start:stop,:2], ha_pts, 'euclidean')
    dist2pt[start:stop,2] = chunk_dist.min(axis=1)

#re-sort array by distance metric
dist2pt_sorted = dist2pt[dist2pt[:,2].argsort()]

#identify closest indices within some integer multiple of A
# choice of A*3 is arbitrary. The window is clamped to the number of available
# vasculature pixels so a small vessel cannot index past the end of the table.
window = min(A*3, V)
if window < A*3:
    print("warning - only", V, "vasculature pixels available; requested window was", A*3)

nearest = dist2pt_sorted[:window]
vascROI = np.zeros(np.shape(mask_Vasc_2))
vascROI[nearest[:,1].astype(int), nearest[:,0].astype(int)] = 255
total_5HT_outer = float(nearest[:,3].sum())

vascROI_img = img.fromarray(vascROI).convert('LA')
plt.imshow(vascROI_img) #check this image as a QC step for each cell run through the pipeline

#%% Compute a Background Value - Extravascular Serotonin Intensity

#using the thresholded vasculature winds up having a lot of contribution from vasculature just outside the bound
#the following code dilates vessel mask to avoid this issue when calculating background
# Blurring the binary mask and then thresholding at 0 turns every pixel that
# received any blurred vessel signal into 255.
blur_ib4_2 = cv2.GaussianBlur(thresh_ib4,(15,15),0)
ret_ib4_2, thresh_ib4_2 = cv2.threshold(blur_ib4_2,0,255,cv2.THRESH_BINARY)
#plt.imshow(thresh_ib4_2)

# Background = every pixel outside the dilated vessel mask.
bg_mask = thresh_ib4_2 == 0
bg_pts = int(np.count_nonzero(bg_mask))
if bg_pts == 0:
    raise ValueError("no background pixels found outside the dilated vasculature mask")

# Sum with a 64-bit accumulator. A running `total += ser_vals[y,x]` can stay in
# the image's own dtype under NumPy >= 2 promotion rules and wrap around if
# that dtype is a narrow integer.
total_5HT_bg = float(ser_vals[bg_mask].sum(dtype=np.float64))

avg_bg_val = total_5HT_bg / bg_pts

#%% Calculate Coupling Index
# Mean serotonin intensity in the vasculature ROI (outer) and in the HA+ cell (inner).
avg_5HT_outer = total_5HT_outer / window
avg_5HT_inner = total_5HT_inner / A

# Subtract the background from both means before taking outer / inner.
outer_minus_bg = avg_5HT_outer - avg_bg_val
inner_minus_bg = avg_5HT_inner - avg_bg_val
corr_ratio = outer_minus_bg / inner_minus_bg

print("***Background-Corrected Coupling Index***", round(corr_ratio,3), sep="\n")